{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "be350003",
   "metadata": {},
   "source": [
    "# Advanced Q&A with LlamaIndex\n",
    "\n",
    "This notebook demonstrates how to use [LlamaIndex](https://docs.llamaindex.ai/en/stable/) to build a more complex retrieval for a chatbot. \n",
    "\n",
    "The retrieval method shown in this notebook works well for code documentation; it retrieves more contiguous document blocks that preserve both code snippets and explanations of code.  \n",
    "\n",
    "<div class=\"alert alert-block alert-info\">\n",
    "    \n",
    "⚠️ There are many node parsing and retrieval techniques supported in LlamaIndex and this notebook just shows how two of these techniques, [HierarchialNodeParser](https://docs.llamaindex.ai/en/stable/api_reference/service_context/node_parser.html) and [AutoMergingRetriever](https://docs.llamaindex.ai/en/stable/examples/retrievers/auto_merging_retriever.html), can be useful for chatting with code documentation. \n",
    "</div>\n",
    "\n",
    "In this demo, we'll use the [`llama_docs_bot`](https://github.com/run-llama/llama_docs_bot/tree/main) GitHub repository as our sample documentation to query. This repository contains the content for a development series with LlamaIndex covering the following topics: \n",
    "- LLMs\n",
    "- Nodes and documents\n",
    "- Evaluation\n",
    "- Embeddings\n",
    "- Retrieval"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "547a35c9",
   "metadata": {},
   "source": [
    "### Step 1: Prerequisite Setup\n",
    "By now you should be familiar with these steps:\n",
    "1. Create an LLM client.\n",
    "2. Set the prompt template for the LLM.\n",
    "3. Download embeddings.\n",
    "4. Set the service context.\n",
    "5. Split the text\n",
    "\n",
    "<div class=\"alert alert-block alert-warning\">\n",
    "    \n",
    "<b>WARNING!</b> Be sure to replace `server_url` with the address and port that Triton is running on. \n",
    "\n",
    "</div>\n",
    "\n",
    "Use the address and port that the Triton is available on; for example `localhost:8001`. **If you are running this notebook as part of the generative ai workflow, you can use the existing url."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df278873",
   "metadata": {},
   "outputs": [],
   "source": [
    "from triton_trt_llm import TensorRTLLM\n",
    "from llama_index.llms import LangChainLLM\n",
    "trtllm =TensorRTLLM(server_url =\"llm:8001\", model_name=\"ensemble\", tokens=500)\n",
    "llm = LangChainLLM(llm=trtllm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80b01c15",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index import Prompt\n",
    "\n",
    "LLAMA_PROMPT_TEMPLATE = (\n",
    " \"<s>[INST] <<SYS>>\"\n",
    " \"Use the following context to answer the user's question. If you don't know the answer, just say that you don't know, don't try to make up an answer.\"\n",
    " \"<</SYS>>\"\n",
    " \"<s>[INST] Context: {context_str} Question: {query_str} Only return the helpful answer below and nothing else. Helpful answer:[/INST]\"\n",
    ")\n",
    "\n",
    "qa_template = Prompt(LLAMA_PROMPT_TEMPLATE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30d8fb1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.embeddings import HuggingFaceEmbeddings\n",
    "from llama_index.embeddings import LangchainEmbedding\n",
    "from llama_index import ServiceContext, set_global_service_context\n",
    "\n",
    "model_kwargs = {\"device\": \"cpu\"}\n",
    "encode_kwargs = {\"normalize_embeddings\": False}\n",
    "hf_embeddings = HuggingFaceEmbeddings(\n",
    "    model_name=\"intfloat/e5-large-v2\",\n",
    "    model_kwargs=model_kwargs,\n",
    "    encode_kwargs=encode_kwargs,\n",
    ")\n",
    "# Load in a specific embedding model\n",
    "embed_model = LangchainEmbedding(hf_embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c47d0437",
   "metadata": {},
   "outputs": [],
   "source": [
    "service_context = ServiceContext.from_defaults(\n",
    "  llm=llm,\n",
    "  embed_model=embed_model\n",
    ")\n",
    "set_global_service_context(service_context)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb9b7488",
   "metadata": {},
   "source": [
    "When splitting the text, we split it into a parent node of 1024 tokens and two children nodes of 510 tokens. Our leaf nodes' maximum size is 512 tokens, so we need to make the largest leaves that can exist under 512 tokens. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "797170a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.text_splitter import TokenTextSplitter\n",
    "text_splitter_ids = [\"1024\", \"510\"]\n",
    "text_splitter_map = {}\n",
    "for ids in text_splitter_ids:\n",
    "    text_splitter_map[ids] = TokenTextSplitter(\n",
    "        chunk_size=int(ids),\n",
    "        chunk_overlap=200\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1395c0e",
   "metadata": {},
   "source": [
    "### Step 2: Clone the Llama Docs Bot Repo \n",
    "This repository will be our sample documentation that we chat with. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbf2a3ae",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!git clone https://github.com/run-llama/llama_docs_bot.git"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e977334a",
   "metadata": {},
   "source": [
    "### Step 3: Define Document Loading and Node Parsing Function\n",
    "\n",
    "Assuming hierarchical node parsing is set to true, this function:\n",
    "- Parses each directory into a single giant document\n",
    "- Chunks the document into a hierarchy of nodes with a top-level chunk size (1024) and children chunks that are smaller (aka **hierarchical node parsing**)\n",
    "   ```\n",
    "         1024\n",
    "      /--------\\\n",
    "  1024//2     1024//2\n",
    "\n",
    "   ```\n",
    "\n",
    "#### Hierarchical Node Parser\n",
    "The novel part of this step is using LlamaIndex's [**Hierarchical Node Parser**](https://docs.llamaindex.ai/en/stable/api/llama_index.core.node_parser.HierarchicalNodeParser.html#llama_index.core.node_parser.HierarchicalNodeParser). This parses nodes into several chunk sizes. \n",
    "\n",
    "During retrieval, if a majority of chunks are retrieved that have the same parent chunk, the larger parent chunk is returned instead of the smaller chunks.\n",
    "\n",
    "#### Simple Node Parser\n",
    "If hierarchical parsing is false, a simple node structure is used and returned."
   ]
  },
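  {
   "cell_type": "markdown",
   "id": "3f9a1c20",
   "metadata": {},
   "source": [
    "As a rough illustration of the parent/child structure described above, the sketch below builds a two-level hierarchy in plain Python. It is not how `HierarchicalNodeParser` is implemented: `split_tokens` and `build_hierarchy` are hypothetical helpers, the tokens are placeholder strings, and the sizes (8 and 4) are scaled-down stand-ins for 1024 and 510.\n",
    "\n",
    "```python\n",
    "# Toy illustration only: placeholder tokens instead of a real tokenizer.\n",
    "def split_tokens(tokens, chunk_size, overlap=0):\n",
    "    # Slide a window of chunk_size over the tokens, stepping by chunk_size - overlap.\n",
    "    step = chunk_size - overlap\n",
    "    chunks = []\n",
    "    for start in range(0, len(tokens), step):\n",
    "        chunks.append(tokens[start:start + chunk_size])\n",
    "        if start + chunk_size >= len(tokens):\n",
    "            break\n",
    "    return chunks\n",
    "\n",
    "\n",
    "def build_hierarchy(tokens, parent_size=8, child_size=4):\n",
    "    # Return a list of (parent_chunk, [child_chunks]) pairs.\n",
    "    return [\n",
    "        (parent, split_tokens(parent, child_size))\n",
    "        for parent in split_tokens(tokens, parent_size)\n",
    "    ]\n",
    "\n",
    "\n",
    "tokens = [f't{i}' for i in range(16)]\n",
    "hierarchy = build_hierarchy(tokens)\n",
    "for parent, children in hierarchy:\n",
    "    print(len(parent), [len(child) for child in children])\n",
    "```\n",
    "\n",
    "Each parent keeps a reference to its smaller children, which is what later lets a retriever swap several retrieved leaves for their shared parent."
   ]
  },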
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cff5b264",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index import SimpleDirectoryReader, Document\n",
    "from llama_index.node_parser import HierarchicalNodeParser, SimpleNodeParser, get_leaf_nodes\n",
    "from llama_index.schema import MetadataMode\n",
    "from llama_docs_bot.llama_docs_bot.markdown_docs_reader import MarkdownDocsReader\n",
    "\n",
    "# This function takes in a directory of files, puts them in a giant document, and parses and returns them as:\n",
    "# - a hierarchical node structure if it's a hierarchical implementation\n",
    "# - a simple node structure if it's a non-hierarchial implementation\n",
    "def load_markdown_docs(filepath, hierarchical=True):\n",
    "    \"\"\"Load markdown docs from a directory, excluding all other file types.\"\"\"\n",
    "    loader = SimpleDirectoryReader(\n",
    "        input_dir=filepath,\n",
    "        required_exts=[\".md\"],\n",
    "        file_extractor={\".md\": MarkdownDocsReader()},\n",
    "        recursive=True\n",
    "    )\n",
    "\n",
    "    documents = loader.load_data()\n",
    "\n",
    "    if hierarchical:\n",
    "        # combine all documents into one\n",
    "        documents = [\n",
    "            Document(text=\"\\n\\n\".join(\n",
    "                    document.get_content(metadata_mode=MetadataMode.ALL)\n",
    "                    for document in documents\n",
    "                )\n",
    "            )\n",
    "        ]\n",
    "\n",
    "        # chunk into 3 levels\n",
    "        # majority means 2/3 are retrieved before using the parent\n",
    "        large_chunk_size = 1536\n",
    "        node_parser = HierarchicalNodeParser.from_defaults(node_parser_ids=text_splitter_ids, node_parser_map=text_splitter_map)\n",
    "\n",
    "        nodes = node_parser.get_nodes_from_documents(documents)\n",
    "        return nodes, get_leaf_nodes(nodes)\n",
    "    ########## This is NOT a hierarchical parser for demonstration purposes later in the notebook ##########\n",
    "    else:\n",
    "        node_parser = SimpleNodeParser.from_defaults()\n",
    "        nodes = node_parser.get_nodes_from_documents(documents)\n",
    "        return nodes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3905a5b6",
   "metadata": {},
   "source": [
    "### Step 4: Load and Parse Documents with Node Parser \n",
    "\n",
    "First, we define all of the documentation directories we want to pull from. \n",
    "\n",
    "Next, we load the documentation and store parent nodes in a `SimpleDocumentStore` and leaf nodes in a `VectorStoreIndex`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6311a28d",
   "metadata": {},
   "outputs": [],
   "source": [
    "docs_directories = {\n",
    "    \"./llama_docs_bot/docs/community\": \"Useful for information on community integrations with other libraries, vector dbs, and frameworks.\",\n",
    "    \"./llama_docs_bot/docs/core_modules/agent_modules\": \"Useful for information on data agents and tools for data agents.\",\n",
    "    \"./llama_docs_bot/docs/core_modules/data_modules\": \"Useful for information on data, storage, indexing, and data processing modules.\",\n",
    "    \"./llama_docs_bot/docs/core_modules/model_modules\": \"Useful for information on LLMs, embedding models, and prompts.\",\n",
    "    \"./llama_docs_bot/docs/core_modules/query_modules\": \"Useful for information on various query engines and retrievers, and anything related to querying data.\",\n",
    "    \"./llama_docs_bot/docs/core_modules/supporting_modules\": \"Useful for information on supporting modules, like callbacks, evaluators, and other supporting modules.\",\n",
    "    \"./llama_docs_bot/docs/getting_started\": \"Useful for information on getting started with LlamaIndex.\",\n",
    "    \"./llama_docs_bot/docs/development\": \"Useful for information on contributing to LlamaIndex development.\",\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ef673e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index import VectorStoreIndex,StorageContext, load_index_from_storage\n",
    "from llama_index.query_engine import RetrieverQueryEngine\n",
    "\n",
    "from llama_index.tools import QueryEngineTool, ToolMetadata\n",
    "from llama_index.storage.docstore import SimpleDocumentStore\n",
    "import os\n",
    "import time\n",
    "\n",
    "start_time = time.time()\n",
    "for directory, description in docs_directories.items():\n",
    "    nodes, leaf_nodes = load_markdown_docs(directory, hierarchical=True)\n",
    "\n",
    "    docstore = SimpleDocumentStore()\n",
    "    docstore.add_documents(nodes)\n",
    "    storage_context = StorageContext.from_defaults(docstore=docstore)\n",
    "\n",
    "    index = VectorStoreIndex(leaf_nodes, storage_context=storage_context)\n",
    "    index.storage_context.persist(persist_dir=f\"./data_{os.path.basename(directory)}\")\n",
    "\n",
    "print(f\"--- {time.time() - start_time} seconds ---\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "511e1c67",
   "metadata": {},
   "source": [
    "### Step 5: Define Custom Node Post-Processor\n",
    "\n",
    "A [**Node PostProcessor**](https://docs.llamaindex.ai/en/stable/module_guides/querying/node_postprocessors/node_postprocessors.html#node-postprocessor-modules) takes a list of retrieved nodes and transforms them (filtering, replacement, etc). \n",
    "\n",
    "This custom node post-processor provides a simple approach to approximate token counts and returns the most nodes that fit within the token count (2500 tokens). Nodes are already sorted, so the most similar ones are returned first. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "546442b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Callable, Optional\n",
    "\n",
    "from llama_index.utils import globals_helper, get_tokenizer\n",
    "from llama_index.schema import MetadataMode\n",
    "\n",
    "class LimitRetrievedNodesLength:\n",
    "\n",
    "    def __init__(self, limit: int = 2500, tokenizer: Optional[Callable] = None):\n",
    "        self._tokenizer = tokenizer or get_tokenizer()\n",
    "\n",
    "        self.limit = limit\n",
    "\n",
    "    def postprocess_nodes(self, nodes, query_bundle):\n",
    "        included_nodes = []\n",
    "        current_length = 0\n",
    "\n",
    "        for node in nodes:\n",
    "            current_length += len(self._tokenizer(node.node.get_content(metadata_mode=MetadataMode.LLM)))\n",
    "            if current_length > self.limit:\n",
    "                break\n",
    "            included_nodes.append(node)\n",
    "\n",
    "        return included_nodes"
   ]
  },
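  {
   "cell_type": "markdown",
   "id": "7d2e4b61",
   "metadata": {},
   "source": [
    "The budgeting logic used by `LimitRetrievedNodesLength` can be tried in isolation. The sketch below is a simplified stand-in, not part of LlamaIndex: `limit_by_tokens` is a hypothetical helper that works on plain strings, and whitespace splitting is assumed as a crude substitute for a real tokenizer.\n",
    "\n",
    "```python\n",
    "def limit_by_tokens(texts, limit, tokenizer=str.split):\n",
    "    # Keep texts in ranked order until the running token count would exceed the limit.\n",
    "    kept, used = [], 0\n",
    "    for text in texts:\n",
    "        used += len(tokenizer(text))\n",
    "        if used > limit:\n",
    "            break\n",
    "        kept.append(text)\n",
    "    return kept\n",
    "\n",
    "\n",
    "# Texts are assumed to be sorted from most to least similar.\n",
    "ranked = ['alpha beta gamma', 'delta epsilon', 'zeta eta theta iota']\n",
    "kept = limit_by_tokens(ranked, limit=5)\n",
    "print(kept)\n",
    "```\n",
    "\n",
    "Because the loop stops at the first text that overflows the budget, a shorter, lower-ranked text that would still fit is not considered. This keeps the result a strict prefix of the ranked list."
   ]
  },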
  {
   "cell_type": "markdown",
   "id": "1fbddf64",
   "metadata": {},
   "source": [
    "### Step 5: Build the Retriever and Query Engine\n",
    "\n",
    "#### AutoMergingRetriever\n",
    "The [`AutoMergingRetriever`](https://docs.llamaindex.ai/en/stable/examples/retrievers/auto_merging_retriever.html) takes in a set of leaf nodes and recursively merges subsets of leaf nodes that reference a parent node beyond a given threshold. This allows for a consolidation of potentially disparate, smaller contexts into a larger context that may help synthesize disparate information. \n",
    "\n",
    "#### Query Engine\n",
    "A query engine is an object that takes in a query and returns a response.\n",
    "\n",
    "It may contain the following components:\n",
    "- **Retriever**: Given a query, retrieves relevant nodes.\n",
    "    - This example uses an `AutoMergingRetriever` if it's a hierarchial implementation.\n",
    "      *This replaces the retrieved nodes with the larger parent chunk*. \n",
    "- **Node PostProcessor**: Takes a list of retrieved nodes and transforms them (filtering, replacement, etc.)\n",
    "    - This example uses a post-processor that filters the retrieved nodes to a limited length. \n",
    "- **Response Synthesizer**: Takes a list of relevant nodes and synthesizes a response with an LLM."
   ]
  },
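  {
   "cell_type": "markdown",
   "id": "9b5c8e13",
   "metadata": {},
   "source": [
    "To make the merging idea concrete, here is a minimal single-level sketch in plain Python. It is not the `AutoMergingRetriever` implementation: `auto_merge` and `PARENT_OF` are hypothetical names, the 0.5 threshold is an assumed value chosen for illustration, and the real retriever applies the merge recursively across levels.\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "# Hypothetical toy data: each leaf id maps to its parent id.\n",
    "PARENT_OF = {'a1': 'A', 'a2': 'A', 'b1': 'B', 'b2': 'B', 'b3': 'B'}\n",
    "\n",
    "\n",
    "def auto_merge(retrieved, parent_of, threshold=0.5):\n",
    "    # Count how many children each parent has in total.\n",
    "    total = defaultdict(int)\n",
    "    for parent in parent_of.values():\n",
    "        total[parent] += 1\n",
    "\n",
    "    # Group the retrieved leaves by their parent.\n",
    "    hits = defaultdict(list)\n",
    "    for leaf in retrieved:\n",
    "        hits[parent_of[leaf]].append(leaf)\n",
    "\n",
    "    merged = []\n",
    "    for parent, leaves in hits.items():\n",
    "        if len(leaves) / total[parent] > threshold:\n",
    "            merged.append(parent)  # enough siblings retrieved: return the parent\n",
    "        else:\n",
    "            merged.extend(leaves)  # otherwise keep the individual leaves\n",
    "    return merged\n",
    "\n",
    "\n",
    "result = auto_merge(['a1', 'a2', 'b1'], PARENT_OF)\n",
    "print(result)\n",
    "```\n",
    "\n",
    "Both children of `A` were retrieved, so they collapse into `A`. Only one of the three children of `B` was retrieved, so `b1` is returned as is."
   ]
  },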
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a76d6820",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.retrievers import AutoMergingRetriever\n",
    "from llama_index.query_engine import RetrieverQueryEngine\n",
    "\n",
    "retriever = AutoMergingRetriever(\n",
    "        index.as_retriever(similarity_top_k=12),\n",
    "        storage_context=storage_context\n",
    "    )\n",
    "\n",
    "query_engine = RetrieverQueryEngine.from_args(\n",
    "    retriever,\n",
    "    text_qa_template=qa_template,\n",
    "    node_postprocessors=[LimitRetrievedNodesLength(limit=2500)],\n",
    "    streaming=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b294862c",
   "metadata": {},
   "source": [
    "### Step 6: Stream Response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "695a42a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"How do I setup a weaviate vector db? Give me a code sample please.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b319f0c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "start_time = time.time()\n",
    "response = query_engine.query(query)\n",
    "response.print_response_stream()\n",
    "print(f\"\\n--- {time.time() - start_time} seconds ---\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ec1e497",
   "metadata": {},
   "source": [
    "To clear out cached data run:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9892b18a",
   "metadata": {},
   "outputs": [],
   "source": [
    "!rm -rf data_*"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
